// File: hand.cpp
//
// Description: Handwriting recognition system.
//              This demo uses the TNeuralNet Class
//

#include <cstdio>       // std::snprintf

#include "applib.h"
#include "TNeural.h"
#include "int_vect.h"   // Use smart vector to store points
#include "double_vect.h"   // Use smart vector to store points

// Geometry of the down-sampled character image and number of classes.
const int MAX_EXEMPLARS = 10;       // distinct characters (output classes) to train
const int YSIZE = 7;                // rows in the 2D training grid
const int XSIZE = 5;                // columns in the 2D training grid
const int SIZE_2D = XSIZE * YSIZE;  // total grid cells

// Mouse samples for the character currently being drawn; only the first
// nump entries of xp/yp are meaningful.
int nump = 0;
int_vect xp(200);   // x coordinates of the captured points
int_vect yp(200);   // y coordinates of the captured points

//  USER DEFINED CALLBACK(S):

// Mode of operation shared by the menu, mouse and display callbacks:
//  0: idle -- mouse presses are not recorded as stroke points
//  1: recording training data
//  2: training has been run (set after do_menu_action finishes training
//     and saving the weights)
//  3: testing recognition
static int current_mode = 0;

// Neural network: one input per grid cell (SIZE_2D = XSIZE x YSIZE),
// NUM_HIDDEN_NEURONS hidden neurons, and one output neuron per trainable
// symbol (MAX_EXEMPLARS).  The hidden-layer size is a named constant so
// that this description cannot drift away from the value actually passed.
// NOTE(review): weights written to test.dat presumably depend on this
// topology; changing it likely invalidates existing weight files.
const int NUM_HIDDEN_NEURONS = 4;
TNeuralNet nnet(SIZE_2D, NUM_HIDDEN_NEURONS, MAX_EXEMPLARS);

// Training set storage.  save_training_data appends examples back-to-back:
// SIZE_2D input values and MAX_EXEMPLARS target values per example.  Both
// buffers are sized from the same example count so that they always agree
// on how many examples fit.  (Before, 2500 inputs fit only 71 examples
// while 800 targets fit 80.)
const int MAX_TRAINING_EXAMPLES = 80;
double_vect input_examples(MAX_TRAINING_EXAMPLES * SIZE_2D),
            output_examples(MAX_TRAINING_EXAMPLES * MAX_EXEMPLARS);
static int num_training_examples = 0;    // examples recorded so far
static int training_input_counter = 0;   // next free slot in input_examples
static int training_output_counter = 0;  // next free slot in output_examples

// Network input for the current drawing; filled by convert_to_2D.
double in_vect[SIZE_2D];

// Label for each output neuron / training step: symbols[k] is the character
// the user is prompted to draw at training step k, and the text shown when
// output k wins recognition.  Declared as pointers to const because binding
// string literals to non-const char* is ill-formed in standard C++11+.
const char *symbols[] = {"0","1","2","3","4","5","6",
                         "7","8","9"};

// Convert the points captured while drawing with the mouse (the first nump
// entries of xp/yp) into the SIZE_2D-element network input vector in_vect
// (SIZE_2D = YSIZE x XSIZE).
//
// The points' bounding box is stretched onto an XSIZE-column by YSIZE-row
// grid.  Every cell hit by at least one point is set to 0.9, and every other
// cell is set to 0.1.  Cell (column x, row y) is stored at
// in_vect[x*YSIZE + y].
//
// tw: the calling window.  It is referenced only by the commented-out debug
//     plot below.
void convert_to_2D(TAppWindow *tw)
{
    // Declared once, before the loops.  Reusing a for-init variable after
    // its loop has ended is ill-formed in standard C++, while re-declaring
    // it in every loop breaks old compilers that leak for-scope.
    int i;

    // Bounding box of the captured points.
    int min_x = 9999;
    int max_x = -1;
    int min_y = 9999;
    int max_y = -1;
    for (i = 0; i < nump; i++)
    {
        if (min_x > xp[i]) min_x = xp[i];
        if (max_x < xp[i]) max_x = xp[i];
        if (min_y > yp[i]) min_y = yp[i];
        if (max_y < yp[i]) max_y = yp[i];
    }

    // Force a non-zero extent so the divisions below are well defined.
    // In a degenerate box (e.g. a perfectly vertical stroke) this maps every
    // point onto the last column/row of the grid.
    if (min_x >= max_x) min_x = max_x - 1;
    if (min_y >= max_y) min_y = max_y - 1;

    // Background level for cells that no point falls into.
    for (i = 0; i < SIZE_2D; i++) in_vect[i] = 0.1;

    for (i = 0; i < nump; i++)
    {
        // Scale into [0, XSIZE] x [0, YSIZE].  The maximum coordinate lands
        // exactly on XSIZE/YSIZE and is clamped into the last cell below.
        int x = static_cast<int>((static_cast<float>(xp[i] - min_x) /
                                  static_cast<float>(max_x - min_x)) * XSIZE);
        int y = static_cast<int>((static_cast<float>(yp[i] - min_y) /
                                  static_cast<float>(max_y - min_y)) * YSIZE);
        if (x < 0) x = 0;
        if (x > (XSIZE-1)) x = (XSIZE-1);
        if (y < 0) y = 0;
        if (y > (YSIZE-1)) y = (YSIZE-1);
        in_vect[x*YSIZE + y] = 0.9;
//      tw->plot_line(10+x,70+y,10+x,70+y);
    }
}

// Index of the symbol currently being recorded (indexes symbols[]).
static int training_step = 0;

// Append the current drawing as one training example.
//  - input_examples receives the SIZE_2D grid produced by convert_to_2D.
//  - output_examples receives a target vector: 0.9 for the symbol being
//    recorded (training_step) and 0.1 for every other output.
// The write counters and num_training_examples advance accordingly.
// NOTE(review): there is no capacity check.  This relies on
// input_examples/output_examples being large enough, or on double_vect
// growing or bounds-checking itself; confirm double_vect semantics.
static void save_training_data(TAppWindow *tw)
{
    // Declared once: standard C++ scopes a for-init variable to its loop,
    // and old compilers reject re-declaring it in a second loop.
    int i;

    convert_to_2D(tw);
    for (i = 0; i < SIZE_2D; i++)
    {
        input_examples[training_input_counter++] = in_vect[i];
    }
    for (i = 0; i < MAX_EXEMPLARS; i++)
    {
        output_examples[training_output_counter++] =
            (i == training_step) ? 0.9 : 0.1;
    }
    num_training_examples++;
}

// Index (into symbols[]) of the most recent recognition result.
static int recall_value = 0;

// Classify the current drawing.  It is rasterised with convert_to_2D and
// run forward through the network.  recall_value is set to the index of the
// output neuron with the largest activation; on ties the lowest index wins.
void recall_pattern(TAppWindow *tw)
{
    convert_to_2D(tw);
    nnet.forward_pass(in_vect);

    // Winner-take-all over the MAX_EXEMPLARS output activations.
    int max_index = 0;
    for (int i = 1; i < MAX_EXEMPLARS; i++)
    {
        if (nnet.output_activations[i] >
            nnet.output_activations[max_index])
        {
            max_index = i;
        }
    }
    recall_value = max_index;
}

// Text shown on the second status line: the symbol to draw, the recognition
// result, or training progress.  It must hold the longest message that
// do_menu_action formats: "Starting training iteration %d,
// num_training_examples=%d" is at most 64 characters plus the terminator.
static char sym_prompt[80];

// Idle callback (presumably invoked by the applib framework between events).
// This demo has no background work to do.
void TAppWindow::idle_proc()
{
    // intentionally empty
}

// Repaint the window:
//  - first line: a label for the current mode (nothing for mode 0);
//  - second line: the current prompt/result text in sym_prompt;
//  - then every captured stroke point, drawn as a short diagonal segment.
void TAppWindow::update_display()
{
    switch (current_mode)
    {
    case 1:
        plot_string(2,15,"Recording");
        break;
    case 2:
        plot_string(2,15,"Training");
        break;
    case 3:
        plot_string(2,15,"Testing");
        break;
    default:
        break;
    }

    plot_string(2,33,sym_prompt);

    for (int k = 0; k < nump; ++k)
    {
        plot_line(xp[k], yp[k], xp[k] + 1, yp[k] + 1);
    }
}

// Nonzero while the mouse button is held down (set here, cleared in mouse_up).
static int down_flag = 0;

// Mouse-press handler.  mouse_move also calls it for every motion event while
// the button is held, so a drag produces a stream of calls.
//
// Only recording (mode 1) and testing (mode 3) react to the pointer:
//  - a press inside the top-left 30x30 corner finishes the current character:
//      mode 1: store it as a training example and advance to the next symbol
//              (dropping to mode 0 after the last one);
//      mode 3: classify it, show the recognised symbol, and drop to mode 0;
//  - a press anywhere else is appended to xp/yp as part of the stroke.
void TAppWindow::mouse_down(int x, int y)
{
    // NOTE(review): draws a fixed diagonal on every call.  This looks like
    // leftover debug output; confirm before removing.
    plot_line(10,10,100,100);
    down_flag = 1;

    const bool recording = (current_mode == 1);
    const bool testing   = (current_mode == 3);
    if (!recording && !testing)
        return;

    const bool in_corner = (x < 30 && y < 30);

    if (in_corner && recording)
    {
        save_training_data(this);
        nump = 0;
        // Busy-wait, presumably meant as a short pause before clearing;
        // an optimizing compiler may remove it entirely.
        for (long spin = 0; spin < 456000; ++spin) { }
        clear_display();
        ++training_step;
        if (training_step >= MAX_EXEMPLARS)
        {
            current_mode = 0;   // every symbol has been recorded
        }
        else
        {
            nump = 0;
            clear_display();
            sprintf(sym_prompt,
                    "%s",
                    symbols[training_step]);
        }
        return;
    }

    if (in_corner && testing)
    {
        recall_pattern(this);
        nump = 0;
        current_mode = 0;
        clear_display();
        sprintf(sym_prompt,
                "%s",
                symbols[recall_value]);
        return;
    }

    // Ordinary stroke point.
    // NOTE(review): there is no capacity check.  xp/yp are constructed with
    // 200 slots, and whether int_vect grows or bounds-checks is not visible
    // here.
    xp[nump] = x;
    yp[nump] = y;
    ++nump;
    plot_line(x, y, x + 1, y + 1);
}

// Button released: stop treating pointer motion as drawing.
void TAppWindow::mouse_up(int /* x */, int /* y */)
{
    down_flag = 0;
}

// While the button is held, each motion event is handled exactly like a
// fresh press, so the stroke keeps accumulating points.
void TAppWindow::mouse_move(int x, int y)
{
    if (!down_flag)
        return;
    mouse_down(x, y);
}

// Menu dispatcher.  item_number identifies the chosen menu_items entry:
//   1  start recording: reset to the first symbol and clear the strokes
//   2  train the network on the recorded examples, then save to test.dat
//   3  restore network weights from test.dat
//   4  enter recognition-test mode (discarding any half-drawn strokes)
// Other values are ignored.
// All writes into sym_prompt are bounded by its size, so long messages are
// truncated rather than overflowing the buffer.
void TAppWindow::do_menu_action(int item_number)
{
    switch (item_number)
    {
    case 1: // Record training data
        current_mode = 1;
        training_step = 0;
        nump = 0;
        clear_display();
        std::snprintf(sym_prompt, sizeof(sym_prompt), "%s", symbols[0]);
        break;

    case 2: // Train network and save weights
    {
        // NOTE(review): this also runs with num_training_examples == 0 and
        // then overwrites test.dat; confirm auto_train handles an empty set.
        for (int j = 0; j < 4; j++)
        {
            std::snprintf(sym_prompt, sizeof(sym_prompt),
                          "Starting training iteration %d, num_training_examples=%d",
                          j,
                          num_training_examples);
            clear_display();
            double error =
                nnet.auto_train(50, // 50 should be 1000
                                num_training_examples,
                                &(input_examples[0]),
                                &(output_examples[0]));
            std::snprintf(sym_prompt, sizeof(sym_prompt), "Error=%f", error);
            clear_display();
        }
        nnet.save("test.dat");
        current_mode = 2;
        break;
    }

    case 3: // Reload saved weights
        nnet.restore("test.dat");
        break;

    case 4: // Test recognition
        current_mode = 3;
        // Like item 1, start from an empty stroke.  Otherwise points left
        // over from an abandoned recording would be fed into recall_pattern.
        nump = 0;
        clear_display();
        break;

    default:
        break;
    }
}


// Menu labels handed to INIT_PROGRAM below.  Presumably entry k corresponds
// to item number k+1 in do_menu_action; confirm against applib.h.
static char *menu_items[] =
{
    "Enter data recording mode",
    "Train network & save weights",
    "Reload saved weights",
    "Test recognition"
};

// Program start-up.  INIT_PROGRAM and RUN_PROGRAM are macros from applib.h.
// Judging by the closing brace below, INIT_PROGRAM opens the body of the
// start-up function (presumably main/WinMain; confirm in applib.h).
// Arguments: presumably the window title, the number of menu entries (must
// match the 4 entries in menu_items), and the menu label array.
INIT_PROGRAM("Hand", 4, menu_items)
   // anything can go here
   // Start with an empty prompt line.  This is redundant with static
   // zero-initialisation, but makes the initial state explicit.
   sym_prompt[0] = '\0';
   RUN_PROGRAM;
}
